{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Tce3stUlHN0L"
      },
      "source": [
        "##### Copyright 2020 The TensorFlow Authors."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "IcfrhafzkZbH"
      },
      "outputs": [],
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "qFdPvlXBOdUN"
      },
      "source": [
        "# Quantization aware training comprehensive guide"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "MfBg1C5NB3X0"
      },
      "source": [
        "<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://www.tensorflow.org/model_optimization/guide/quantization/training_comprehensive_guide\"><img src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" />View on TensorFlow.org</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/model-optimization/blob/master/tensorflow_model_optimization/g3doc/guide/quantization/training_comprehensive_guide.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://github.com/tensorflow/model-optimization/blob/master/tensorflow_model_optimization/g3doc/guide/quantization/training_comprehensive_guide.ipynb\"><img src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />View source on GitHub</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a href=\"https://storage.googleapis.com/tensorflow_docs/model-optimization/tensorflow_model_optimization/g3doc/guide/quantization/training_comprehensive_guide.ipynb\"><img src=\"https://www.tensorflow.org/images/download_logo_32px.png\" />Download notebook</a>\n",
        "  </td>\n",
        "</table>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "FbORZA_bQx1G"
      },
      "source": [
        "Welcome to the comprehensive guide for Keras quantization aware training.\n",
        "\n",
        "This page documents various use cases and shows how to use the API for each one. Once you know which APIs you need, find the parameters and the low-level details in the\n",
        "[API docs](https://www.tensorflow.org/model_optimization/api_docs/python/tfmot/quantization).\n",
        "\n",
        "*  If you want to see the benefits of quantization aware training and what's supported, see the [overview](https://www.tensorflow.org/model_optimization/guide/quantization/training.md).\n",
        "*  For a single end-to-end example, see the [quantization aware training example](https://www.tensorflow.org/model_optimization/guide/quantization/training_example.md).\n",
        "\n",
        "The following use cases are covered:\n",
        "\n",
        "* Deploy a model with 8-bit quantization with these steps.\n",
        "  * Define a quantization aware model.\n",
        "  * For Keras HDF5 models only, use special checkpointing and\n",
        "    deserialization logic. Training is otherwise standard.\n",
        "  * Create a quantized model from the quantization aware one.\n",
        "* Experiment with quantization.\n",
        "  * Anything for experimentation has no supported path to deployment.\n",
        "  * Custom Keras layers fall under experimentation."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "nuABqZnXVDvO"
      },
      "source": [
        "## Setup"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "qqnbd7TOfAq9"
      },
      "source": [
        "For finding the APIs you need and understanding purposes, you can run but skip reading this section."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "both",
        "id": "lvpH1Hg7ULFz"
      },
      "outputs": [],
      "source": [
        "! pip uninstall -y tensorflow\n",
        "! pip install -q tf-nightly\n",
        "! pip install -q tensorflow-model-optimization\n",
        "\n",
        "import tensorflow as tf\n",
        "import numpy as np\n",
        "import tensorflow_model_optimization as tfmot\n",
        "\n",
        "import tempfile\n",
        "\n",
        "input_shape = [20]\n",
        "x_train = np.random.randn(1, 20).astype(np.float32)\n",
        "y_train = tf.keras.utils.to_categorical(np.random.randn(1), num_classes=20)\n",
        "\n",
        "def setup_model():\n",
        "  model = tf.keras.Sequential([\n",
        "      tf.keras.layers.Dense(20, input_shape=input_shape),\n",
        "      tf.keras.layers.Flatten()\n",
        "  ])\n",
        "  return model\n",
        "\n",
        "def setup_pretrained_weights():\n",
        "  model= setup_model()\n",
        "\n",
        "  model.compile(\n",
        "      loss=tf.keras.losses.categorical_crossentropy,\n",
        "      optimizer='adam',\n",
        "      metrics=['accuracy']\n",
        "  )\n",
        "\n",
        "  model.fit(x_train, y_train)\n",
        "\n",
        "  _, pretrained_weights = tempfile.mkstemp('.tf')\n",
        "\n",
        "  model.save_weights(pretrained_weights)\n",
        "\n",
        "  return pretrained_weights\n",
        "\n",
        "def setup_pretrained_model():\n",
        "  model = setup_model()\n",
        "  pretrained_weights = setup_pretrained_weights()\n",
        "  model.load_weights(pretrained_weights)\n",
        "  return model\n",
        "\n",
        "setup_model()\n",
        "pretrained_weights = setup_pretrained_weights()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "dTHLMLV-ZrUA"
      },
      "source": [
        "##Define quantization aware model"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "0U6XAUhIe6re"
      },
      "source": [
        "By defining models in the following ways, there are available paths to deployment to backends listed in the [overview page](https://www.tensorflow.org/model_optimization/guide/quantization/training.md). By default, 8-bit quantization is used.\n",
        "\n",
        "Note: a quantization aware model is not actually quantized. Creating a quantized model is a separate step."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Ybigft1fTn4T"
      },
      "source": [
        "### Quantize whole model"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "puZvqnp1xsn-"
      },
      "source": [
        "**Your use case:**\n",
        "* Subclassed models are not supported.\n",
        "\n",
        "**Tips for better model accuracy:**\n",
        "\n",
        "* Try \"Quantize some layers\" to skip quantizing the layers that reduce accuracy the most.\n",
        "* It's generally better to finetune with quantization aware training as opposed to training from scratch.\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "_Zhzx_azO1WR"
      },
      "source": [
        "To make the whole model aware of quantization, apply `tfmot.quantization.keras.quantize_model` to the model.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "1s_EK8reOruu"
      },
      "outputs": [],
      "source": [
        "base_model = setup_model()\n",
        "base_model.load_weights(pretrained_weights) # optional but recommended for model accuracy\n",
        "\n",
        "quant_aware_model = tfmot.quantization.keras.quantize_model(base_model)\n",
        "quant_aware_model.summary()"
      ]
    },
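    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Once the model is quantization aware, training proceeds as for any other Keras model. The cell below is a minimal sketch of that step, assuming the toy `x_train` and `y_train` from the Setup section and the same loss and optimizer used there. Those settings are illustrative choices for this guide, not recommendations."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Rebuild the quantization aware model so this cell can run on its own.\n",
        "base_model = setup_model()\n",
        "base_model.load_weights(pretrained_weights)\n",
        "quant_aware_model = tfmot.quantization.keras.quantize_model(base_model)\n",
        "\n",
        "# Assumption: reuse the compile settings from the Setup section.\n",
        "quant_aware_model.compile(\n",
        "    loss=tf.keras.losses.categorical_crossentropy,\n",
        "    optimizer='adam',\n",
        "    metrics=['accuracy']\n",
        ")\n",
        "\n",
        "# Finetune on the single toy example, only to show the call.\n",
        "quant_aware_model.fit(x_train, y_train)"
      ]
    },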
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "xTbTLn3dZM7h"
      },
      "source": [
        "### Quantize some layers"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "MbM8o832xTxV"
      },
      "source": [
        "Quantizing a model can have a negative effect on accuracy. You can selectively quantize layers of a model to explore the trade-off between accuracy, speed, and model size.\n",
        "\n",
        "**Your use case:**\n",
        "* To deploy to a backend that only works well with fully quantized models (e.g. EdgeTPU v1, most DSPs), try \"Quantize whole model\".\n",
        "\n",
        "**Tips for better model accuracy:**\n",
        "* It's generally better to finetune with quantization aware training as opposed to training from scratch.\n",
        "* Try quantizing the later layers instead of the first layers.\n",
        "* Avoid quantizing critical layers (e.g. attention mechanism).\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "3OCbOUWHsE_v"
      },
      "source": [
        "In the example below, quantize only the `Dense` layers."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "HN0B_QB-ZhE2"
      },
      "outputs": [],
      "source": [
        "# Create a base model\n",
        "base_model = setup_model()\n",
        "base_model.load_weights(pretrained_weights) # optional but recommended for model accuracy\n",
        "\n",
        "# Helper function uses `quantize_annotate_layer` to annotate that only the \n",
        "# Dense layers should be quantized.\n",
        "def apply_quantization_to_dense(layer):\n",
        "  if isinstance(layer, tf.keras.layers.Dense):\n",
        "    return tfmot.quantization.keras.quantize_annotate_layer(layer)\n",
        "  return layer\n",
        "\n",
        "# Use `tf.keras.models.clone_model` to apply `apply_quantization_to_dense` \n",
        "# to the layers of the model.\n",
        "annotated_model = tf.keras.models.clone_model(\n",
        "    base_model,\n",
        "    clone_function=apply_quantization_to_dense,\n",
        ")\n",
        "\n",
        "# Now that the Dense layers are annotated,\n",
        "# `quantize_apply` actually makes the model quantization aware.\n",
        "quant_aware_model = tfmot.quantization.keras.quantize_apply(annotated_model)\n",
        "quant_aware_model.summary()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "HiA28PrrW11H"
      },
      "source": [
        "While this example used the type of the layer to decide what to quantize, the easiest way to quantize a particular layer is to set its `name` property, and look for that name in the `clone_function`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "CjY_JyB808Da"
      },
      "outputs": [],
      "source": [
        "print(base_model.layers[0].name)"
      ]
    },
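    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The name-based approach described above can be sketched as follows. The helper `apply_quantization_by_name` and the variable `target_layer_name` are hypothetical names introduced for this illustration. In your own model you would typically pass `name=` when constructing the layer and match on that value."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "base_model = setup_model()\n",
        "base_model.load_weights(pretrained_weights) # optional but recommended for model accuracy\n",
        "\n",
        "# Assumption: the layer to quantize is the first one; use its generated name.\n",
        "target_layer_name = base_model.layers[0].name\n",
        "\n",
        "# Annotate only the layer whose name matches and leave the others as they are.\n",
        "def apply_quantization_by_name(layer):\n",
        "  if layer.name == target_layer_name:\n",
        "    return tfmot.quantization.keras.quantize_annotate_layer(layer)\n",
        "  return layer\n",
        "\n",
        "annotated_model = tf.keras.models.clone_model(\n",
        "    base_model,\n",
        "    clone_function=apply_quantization_by_name,\n",
        ")\n",
        "\n",
        "quant_aware_model = tfmot.quantization.keras.quantize_apply(annotated_model)\n",
        "quant_aware_model.summary()"
      ]
    },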
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "mpb_BydRaSoF"
      },
      "source": [
        "#### More readable but potentially lower model accuracy"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "2vqXeYffzSHp"
      },
      "source": [
        "This is not compatible with finetuning with quantization aware training, which is why it may be less accurate than the above examples."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "MQoMH3g3fWwb"
      },
      "source": [
        "**Functional example**"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "7Wow55hg5oiM"
      },
      "outputs": [],
      "source": [
        "# Use `quantize_annotate_layer` to annotate that the `Dense` layer\n",
        "# should be quantized.\n",
        "i = tf.keras.Input(shape=(20,))\n",
        "x = tfmot.quantization.keras.quantize_annotate_layer(tf.keras.layers.Dense(10))(i)\n",
        "o = tf.keras.layers.Flatten()(x)\n",
        "annotated_model = tf.keras.Model(inputs=i, outputs=o)\n",
        "\n",
        "# Use `quantize_apply` to actually make the model quantization aware.\n",
        "quant_aware_model = tfmot.quantization.keras.quantize_apply(annotated_model)\n",
        "\n",
        "# For deployment purposes, the tool adds `QuantizeLayer` after `InputLayer` so that the\n",
        "# quantized model can take in float inputs instead of only uint8.\n",
        "quant_aware_model.summary()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "wIGj-r2of2ls"
      },
      "source": [
        "**Sequential example**\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "mQOiDUGgfi4y"
      },
      "outputs": [],
      "source": [
        "# Use `quantize_annotate_layer` to annotate that the `Dense` layer\n",
        "# should be quantized.\n",
        "annotated_model = tf.keras.Sequential([\n",
        "  tfmot.quantization.keras.quantize_annotate_layer(tf.keras.layers.Dense(20, input_shape=input_shape)),\n",
        "  tf.keras.layers.Flatten()\n",
        "])\n",
        "\n",
        "# Use `quantize_apply` to actually make the model quantization aware.\n",
        "quant_aware_model = tfmot.quantization.keras.quantize_apply(annotated_model)\n",
        "\n",
        "quant_aware_model.summary()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "MpvX5IqahV1r"
      },
      "source": [
        "## Checkpoint and deserialize"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "GuZ5wlij1dcJ"
      },
      "source": [
        "**Your use case:** this code is only needed for the HDF5 model format (not HDF5 weights or other formats)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "6khQg-q7imfH"
      },
      "outputs": [],
      "source": [
        "# Define the model.\n",
        "base_model = setup_model()\n",
        "base_model.load_weights(pretrained_weights) # optional but recommended for model accuracy\n",
        "quant_aware_model = tfmot.quantization.keras.quantize_model(base_model)\n",
        "\n",
        "# Save or checkpoint the model.\n",
        "_, keras_model_file = tempfile.mkstemp('.h5')\n",
        "quant_aware_model.save(keras_model_file)\n",
        "\n",
        "# `quantize_scope` is needed for deserializing HDF5 models.\n",
        "with tfmot.quantization.keras.quantize_scope():\n",
        "  loaded_model = tf.keras.models.load_model(keras_model_file)\n",
        "\n",
        "loaded_model.summary()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "NeNCMDAbnEKU"
      },
      "source": [
        "## Create and deploy quantized model"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "iiYk_KR0rJ2n"
      },
      "source": [
        "In general, reference the documentation for the deployment backend that you\n",
        "will use.\n",
        "\n",
        "This is an example for the TFLite backend."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "fbBiEetda3R8"
      },
      "outputs": [],
      "source": [
        "base_model = setup_pretrained_model()\n",
        "quant_aware_model = tfmot.quantization.keras.quantize_model(base_model)\n",
        "\n",
        "# Typically you train the model here.\n",
        "\n",
        "converter = tf.lite.TFLiteConverter.from_keras_model(quant_aware_model)\n",
        "converter.optimizations = [tf.lite.Optimize.DEFAULT]\n",
        "\n",
        "quantized_tflite_model = converter.convert()"
      ]
    },
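    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To check that the converted model runs, one option is the TFLite Python interpreter. The cell below is a minimal sketch, assuming the model has a single input and a single output, as the toy model in this guide does. It feeds the toy `x_train` from the Setup section and only confirms that inference executes. It is not a benchmark or an accuracy check."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Load the converted model directly from memory.\n",
        "interpreter = tf.lite.Interpreter(model_content=quantized_tflite_model)\n",
        "interpreter.allocate_tensors()\n",
        "\n",
        "# Assumption: exactly one input tensor and one output tensor.\n",
        "input_details = interpreter.get_input_details()[0]\n",
        "output_details = interpreter.get_output_details()[0]\n",
        "\n",
        "# `x_train` is float32 with shape (1, 20), which is expected to match\n",
        "# the input of the model converted above.\n",
        "interpreter.set_tensor(input_details['index'], x_train)\n",
        "interpreter.invoke()\n",
        "\n",
        "tflite_output = interpreter.get_tensor(output_details['index'])\n",
        "print(tflite_output.shape)"
      ]
    },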
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "v5raSy9ghxkv"
      },
      "source": [
        "## Experiment with quantization"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "LUGpXIET0cy3"
      },
      "source": [
        "**Your use case**: using the following APIs means that there is no\n",
        "supported path to deployment. The features are also experimental and not\n",
        "subject to backward compatibility.\n",
        "  * `tfmot.quantization.keras.QuantizeConfig`\n",
        "  * `tfmot.quantization.keras.quantizers.Quantizer`\n",
        "  * `tfmot.quantization.keras.quantizers.LastValueQuantizer`\n",
        "  * `tfmot.quantization.keras.quantizers.MovingAverageQuantizer`"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Q1KI_FCcU7Yn"
      },
      "source": [
        "### Setup: DefaultDenseQuantizeConfig"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "I6nPkJDRUB2G"
      },
      "source": [
        "Experimenting requires using `tfmot.quantization.keras.QuantizeConfig`, which describes how to quantize the weights, activations, and outputs of a layer.\n",
        "\n",
        "Below is an example that defines the same `QuantizeConfig` used for the `Dense` layer in the API defaults.\n",
        "\n",
        "During the forward propagation in this example, the `LastValueQuantizer` returned in `get_weights_and_quantizers` is called with `layer.kernel` as the input, producing an output. The output replaces `layer.kernel`\n",
        "in the original forward propagation of the `Dense` layer, via the logic defined in `set_quantize_weights`.  The same idea applies to the activations and outputs.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "B9SWK5UQT7VQ"
      },
      "outputs": [],
      "source": [
        "LastValueQuantizer = tfmot.quantization.keras.quantizers.LastValueQuantizer\n",
        "MovingAverageQuantizer = tfmot.quantization.keras.quantizers.MovingAverageQuantizer\n",
        "\n",
        "class DefaultDenseQuantizeConfig(tfmot.quantization.keras.QuantizeConfig):\n",
        "    # Configure how to quantize weights.\n",
        "    def get_weights_and_quantizers(self, layer):\n",
        "      return [(layer.kernel, LastValueQuantizer(num_bits=8, symmetric=True, narrow_range=False, per_axis=False))]\n",
        "\n",
        "    # Configure how to quantize activations.\n",
        "    def get_activations_and_quantizers(self, layer):\n",
        "      return [(layer.activation, MovingAverageQuantizer(num_bits=8, symmetric=False, narrow_range=False, per_axis=False))]\n",
        "\n",
        "    def set_quantize_weights(self, layer, quantize_weights):\n",
        "      # Add this line for each item returned in `get_weights_and_quantizers`\n",
        "      # , in the same order\n",
        "      layer.kernel = quantize_weights[0]\n",
        "\n",
        "    def set_quantize_activations(self, layer, quantize_activations):\n",
        "      # Add this line for each item returned in `get_activations_and_quantizers`\n",
        "      # , in the same order.\n",
        "      layer.activation = quantize_activations[0]\n",
        "\n",
        "    # Configure how to quantize outputs (may be equivalent to activations).\n",
        "    def get_output_quantizers(self, layer):\n",
        "      return []\n",
        "\n",
        "    def get_config(self):\n",
        "      return {}"
      ]
    },
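    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To build intuition for what a weight quantizer does to `layer.kernel` during forward propagation, here is a simplified NumPy sketch of symmetric 8-bit fake quantization: values are scaled onto an integer grid, rounded, clipped, and scaled back to floats. This is an illustration only and not the `LastValueQuantizer` implementation. The function name `fake_quantize_symmetric` and the integer range of [-127, 127] are assumptions made for this sketch."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "\n",
        "def fake_quantize_symmetric(values, num_bits=8):\n",
        "  # Largest integer level on the symmetric grid, e.g. 127 for 8 bits.\n",
        "  max_level = 2 ** (num_bits - 1) - 1\n",
        "  # One scale for the whole tensor (per-tensor, not per-axis).\n",
        "  max_abs = np.max(np.abs(values))\n",
        "  if max_abs == 0:\n",
        "    return values.copy()\n",
        "  scale = max_abs / max_level\n",
        "  # Round to the nearest level, clip to the grid, then map back to float.\n",
        "  levels = np.clip(np.round(values / scale), -max_level, max_level)\n",
        "  return (levels * scale).astype(values.dtype)\n",
        "\n",
        "kernel = np.random.randn(20, 20).astype(np.float32)\n",
        "quantized_kernel = fake_quantize_symmetric(kernel)\n",
        "\n",
        "# The output is still float32, but it takes at most 255 distinct values.\n",
        "print(quantized_kernel.dtype, np.unique(quantized_kernel).size)\n",
        "\n",
        "# The rounding error should not exceed half a quantization step.\n",
        "step = np.max(np.abs(kernel)) / 127\n",
        "print(np.max(np.abs(kernel - quantized_kernel)) <= step / 2 + 1e-6)"
      ]
    },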
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "8vJeoGQG9ZX0"
      },
      "source": [
        "### Quantize custom Keras layer\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "YmyhI_bzWb2w"
      },
      "source": [
        "This example uses the `DefaultDenseQuantizeConfig` to quantize the `CustomLayer`.\n",
        "\n",
        "Applying the configuration is the same across\n",
        "the \"Experiment with quantization\" use cases.\n",
        " * Apply `tfmot.quantization.keras.quantize_annotate_layer` to the `CustomLayer` and pass in the `QuantizeConfig`.\n",
        " * Use\n",
        "`tfmot.quantization.keras.quantize_annotate_model` to continue to quantize the rest of the model with the API defaults.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "7_rBOJdyWWEs"
      },
      "outputs": [],
      "source": [
        "quantize_annotate_layer = tfmot.quantization.keras.quantize_annotate_layer\n",
        "quantize_annotate_model = tfmot.quantization.keras.quantize_annotate_model\n",
        "quantize_scope = tfmot.quantization.keras.quantize_scope\n",
        "\n",
        "class CustomLayer(tf.keras.layers.Dense):\n",
        "  pass\n",
        "\n",
        "model = quantize_annotate_model(tf.keras.Sequential([\n",
        "   quantize_annotate_layer(CustomLayer(20, input_shape=(20,)), DefaultDenseQuantizeConfig()),\n",
        "   tf.keras.layers.Flatten()\n",
        "]))\n",
        "\n",
        "# `quantize_apply` requires mentioning `DefaultDenseQuantizeConfig` with `quantize_scope`\n",
        "# as well as the custom Keras layer.\n",
        "with quantize_scope(\n",
        "  {'DefaultDenseQuantizeConfig': DefaultDenseQuantizeConfig,\n",
        "   'CustomLayer': CustomLayer}):\n",
        "  # Use `quantize_apply` to actually make the model quantization aware.\n",
        "  quant_aware_model = tfmot.quantization.keras.quantize_apply(model)\n",
        "\n",
        "quant_aware_model.summary()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "vnMguvVSnUqD"
      },
      "source": [
        "### Modify quantization parameters\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "BLgH1aFMjTK4"
      },
      "source": [
        "**Common mistake:** quantizing the bias to fewer than 32-bits usually harms model accuracy too much.\n",
        "\n",
        "This example modifies the `Dense` layer to use 4-bits for its weights instead\n",
        "of the default 8-bits. The rest of the model continues to use API defaults.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "77jgBjccnTh6"
      },
      "outputs": [],
      "source": [
        "quantize_annotate_layer = tfmot.quantization.keras.quantize_annotate_layer\n",
        "quantize_annotate_model = tfmot.quantization.keras.quantize_annotate_model\n",
        "quantize_scope = tfmot.quantization.keras.quantize_scope\n",
        "\n",
        "class ModifiedDenseQuantizeConfig(DefaultDenseQuantizeConfig):\n",
        "    # Configure weights to quantize with 4-bit instead of 8-bits.\n",
        "    def get_weights_and_quantizers(self, layer):\n",
        "      return [(layer.kernel, LastValueQuantizer(num_bits=4, symmetric=True, narrow_range=False, per_axis=False))]"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "x9JDKhaU3FKe"
      },
      "source": [
        "Applying the configuration is the same across\n",
        "the \"Experiment with quantization\" use cases.\n",
        " * Apply `tfmot.quantization.keras.quantize_annotate_layer` to the `Dense` layer and pass in the `QuantizeConfig`.\n",
        " * Use\n",
        "`tfmot.quantization.keras.quantize_annotate_model` to continue to quantize the rest of the model with the API defaults."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "sq5mfyBF3KxV"
      },
      "outputs": [],
      "source": [
        "model = quantize_annotate_model(tf.keras.Sequential([\n",
        "   # Pass in modified `QuantizeConfig` to modify this Dense layer.\n",
        "   quantize_annotate_layer(tf.keras.layers.Dense(20, input_shape=(20,)), ModifiedDenseQuantizeConfig()),\n",
        "   tf.keras.layers.Flatten()\n",
        "]))\n",
        "\n",
        "# `quantize_apply` requires mentioning `ModifiedDenseQuantizeConfig` with `quantize_scope`:\n",
        "with quantize_scope(\n",
        "  {'ModifiedDenseQuantizeConfig': ModifiedDenseQuantizeConfig}):\n",
        "  # Use `quantize_apply` to actually make the model quantization aware.\n",
        "  quant_aware_model = tfmot.quantization.keras.quantize_apply(model)\n",
        "\n",
        "quant_aware_model.summary()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "bJMKgzh84CCs"
      },
      "source": [
        "### Modify parts of layer to quantize\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Z3pij2uO808g"
      },
      "source": [
        "This example modifies the `Dense` layer to skip quantizing the activation. The rest of the model continues to use API defaults."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "6BaaJPBR8djV"
      },
      "outputs": [],
      "source": [
        "quantize_annotate_layer = tfmot.quantization.keras.quantize_annotate_layer\n",
        "quantize_annotate_model = tfmot.quantization.keras.quantize_annotate_model\n",
        "quantize_scope = tfmot.quantization.keras.quantize_scope\n",
        "\n",
        "class ModifiedDenseQuantizeConfig(DefaultDenseQuantizeConfig):\n",
        "    def get_activations_and_quantizers(self, layer):\n",
        "      # Skip quantizing activations.\n",
        "      return []\n",
        "\n",
        "    def set_quantize_activations(self, layer, quantize_activations):\n",
        "      # Empty since `get_activaations_and_quantizers` returns\n",
        "      # an empty list.\n",
        "      return"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "2OkqHX5r2nT7"
      },
      "source": [
        "Applying the configuration is the same across\n",
        "the \"Experiment with quantization\" use cases.\n",
        " * Apply `tfmot.quantization.keras.quantize_annotate_layer` to the `Dense` layer and pass in the `QuantizeConfig`.\n",
        " * Use\n",
        "`tfmot.quantization.keras.quantize_annotate_model` to continue to quantize the rest of the model with the API defaults."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Ln9MDIZJ2n3F"
      },
      "outputs": [],
      "source": [
        "model = quantize_annotate_model(tf.keras.Sequential([\n",
        "   # Pass in modified `QuantizeConfig` to modify this Dense layer.\n",
        "   quantize_annotate_layer(tf.keras.layers.Dense(20, input_shape=(20,)), ModifiedDenseQuantizeConfig()),\n",
        "   tf.keras.layers.Flatten()\n",
        "]))\n",
        "\n",
        "# `quantize_apply` requires mentioning `ModifiedDenseQuantizeConfig` with `quantize_scope`:\n",
        "with quantize_scope(\n",
        "  {'ModifiedDenseQuantizeConfig': ModifiedDenseQuantizeConfig}):\n",
        "  # Use `quantize_apply` to actually make the model quantization aware.\n",
        "  quant_aware_model = tfmot.quantization.keras.quantize_apply(model)\n",
        "\n",
        "quant_aware_model.summary()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "yD0sIR6tmmRx"
      },
      "source": [
        "### Use custom quantization algorithm\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "I4onhF-H1zsn"
      },
      "source": [
        "The `tfmot.quantization.keras.quantizers.Quantizer` class is a callable that\n",
        "can apply any algorithm to its inputs.\n",
        "\n",
        "In this example, the inputs are the weights, and we apply the math in the\n",
        "`FixedRangeQuantizer` \\_\\_call\\_\\_ function to the weights. Instead of the original\n",
        "weights values, the output of the\n",
        "`FixedRangeQuantizer` is now passed to whatever would have used the weights."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Jt8UioZH49QV"
      },
      "outputs": [],
      "source": [
        "quantize_annotate_layer = tfmot.quantization.keras.quantize_annotate_layer\n",
        "quantize_annotate_model = tfmot.quantization.keras.quantize_annotate_model\n",
        "quantize_scope = tfmot.quantization.keras.quantize_scope\n",
        "\n",
        "class FixedRangeQuantizer(tfmot.quantization.keras.quantizers.Quantizer):\n",
        "  \"\"\"Quantizer which forces outputs to be between -1 and 1.\"\"\"\n",
        "\n",
        "  def build(self, tensor_shape, name, layer):\n",
        "    # Not needed. No new TensorFlow variables needed.\n",
        "    return {}\n",
        "\n",
        "  def __call__(self, inputs, training, weights, **kwargs):\n",
        "    return tf.keras.backend.clip(inputs, -1.0, 1.0)\n",
        "\n",
        "  def get_config(self):\n",
        "    # Not needed. No __init__ parameters to serialize.\n",
        "    return {}\n",
        "\n",
        "\n",
        "class ModifiedDenseQuantizeConfig(DefaultDenseQuantizeConfig):\n",
        "    # Configure weights to quantize with 4-bit instead of 8-bits.\n",
        "    def get_weights_and_quantizers(self, layer):\n",
        "      # Use custom algorithm defined in `FixedRangeQuantizer` instead of default Quantizer.\n",
        "      return [(layer.kernel, FixedRangeQuantizer())]"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "lu5ZeJ_Y2UxW"
      },
      "source": [
        "Applying the configuration is the same across\n",
        "the \"Experiment with quantization\" use cases.\n",
        " * Apply `tfmot.quantization.keras.quantize_annotate_layer` to the `Dense` layer and pass in the `QuantizeConfig`.\n",
        " * Use\n",
        "`tfmot.quantization.keras.quantize_annotate_model` to continue to quantize the rest of the model with the API defaults."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ItC_3mwT2U87"
      },
      "outputs": [],
      "source": [
        "model = quantize_annotate_model(tf.keras.Sequential([\n",
        "   # Pass in modified `QuantizeConfig` to modify this `Dense` layer.\n",
        "   quantize_annotate_layer(tf.keras.layers.Dense(20, input_shape=(20,)), ModifiedDenseQuantizeConfig()),\n",
        "   tf.keras.layers.Flatten()\n",
        "]))\n",
        "\n",
        "# `quantize_apply` requires mentioning `ModifiedDenseQuantizeConfig` with `quantize_scope`:\n",
        "with quantize_scope(\n",
        "  {'ModifiedDenseQuantizeConfig': ModifiedDenseQuantizeConfig}):\n",
        "  # Use `quantize_apply` to actually make the model quantization aware.\n",
        "  quant_aware_model = tfmot.quantization.keras.quantize_apply(model)\n",
        "\n",
        "quant_aware_model.summary()"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [
        "Tce3stUlHN0L"
      ],
      "name": "training_comprehensive_guide.ipynb",
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
